#! /usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2013 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.

from __future__ import division
from biopy.parseNewick import parseNewick as pn
from biopy.treeutils import *

import argparse, sys, os.path

# Command-line interface. Every option is optional; the single positional
# argument is the tree itself, given as NEWICK text on the command line.
parser = argparse.ArgumentParser(
  description="""%(prog)s generates tikz code for the
input tree. Quite primitive at the moment, only special feature is a horizontal line
across the tree crossing a particular node.
""")

# Horizontal spacing between consecutive tips, in tikz coordinate units.
parser.add_argument(
  "-t", "--tips-distance",
  dest="dtips",
  metavar='N',
  type=float,
  default=1.0,
  help="distance between tips (%(default)g)")

# Extra options appended to the tikzpicture environment options.
parser.add_argument(
  "-o", "--global-options",
  dest="gopts",
  metavar='STR',
  help="comma separated.")

# May be given several times; each occurrence defines (or overrides) one style.
parser.add_argument(
  "-s", "--style",
  action="append",
  metavar='STR',
  help="'name=tikz options' format, options are comma separated.")

parser.add_argument(
  "--standalone",
  action="store_true",
  help="make output a (minimal) valid latex file.")

parser.add_argument(
  'tree',
  metavar='TREE',
  help="tree text (NEWICK)")

options = parser.parse_args()

# Cheap sanity check before invoking the parser: a NEWICK tree starts with '('.
# startswith() is used instead of targ[0] so that an empty or all-blank
# argument produces the error message below rather than an IndexError.
targ = options.tree.strip()
if not targ.startswith('(') :
  sys.stderr.write("Expecting a tree in NEWICK format.\n")
  sys.exit(1)

try :
  tree = pn(targ)
except Exception as e:
  # Top-level boundary of the script: any parser failure is reported and
  # turned into a non-zero exit status. str(e) is used rather than the
  # deprecated e.message, which is empty for multi-argument exceptions.
  sys.stderr.write("*** Error: failed to parse tree:  %s [ %s ]\n" % (e, targ))
  sys.exit(1)

def setTerms(t,n,i,dtips) :
  """Assign x positions to the tips below ``n`` and record each clade's tips.

  Walks the subtree rooted at node ``n`` of tree ``t`` depth first, visiting
  children in ``succ`` order. Each tip gets ``data.x`` set to the running
  position ``i``, after which the position advances by ``dtips``. Every
  visited node gets ``data.c``: the sorted list of ``(x, tip node)`` pairs
  of the tips below it (a tip lists only itself).

  Returns the position the next tip (to the right of this subtree) would get.
  """
  children = n.succ
  if not children :
    # A tip: takes the current slot and is its own one-element clade.
    n.data.x = i
    n.data.c = [(i, n)]
    return i + dtips

  below = []
  for cid in children :
    child = t.node(cid)
    i = setTerms(t, child, i, dtips)
    below.extend(child.data.c)
  below.sort()
  n.data.c = below
  return i

def setX(t,n,p,nh) :
  """Set ``data.x`` for internal node ``n`` and all internal nodes below it.

  ``p`` is the parent of ``n`` (its ``data.x`` must already be set) and
  ``nh`` maps node ids to node heights. ``n`` is placed on the straight line
  between its parent and one of its own extreme tips:

      x = tip_x + (parent_x - tip_x) / height(parent) * height(n)

  The leftmost tip of ``n`` is used when ``n`` is the first child of ``p``,
  the rightmost tip otherwise. Tips are left untouched. Relies on the
  ``data.c`` lists built by setTerms.
  """
  if not n.succ :
    return

  height = nh[n.id]
  # First child leans towards its leftmost tip, any other child towards
  # its rightmost tip.
  which = 0 if p.succ[0] == n.id else -1
  tipx = n.data.c[which][0]
  slope = (p.data.x - tipx) / nh[p.id]
  n.data.x = tipx + slope * height

  for cid in n.succ :
    setX(t, t.node(cid), n, nh)
      

def treePoints(t, dtips = 1):
  """Compute the x coordinate of every node of tree ``t``.

  Tips are spaced ``dtips`` apart starting at 0 (via setTerms), the root is
  centred between the first and last tip, and the remaining internal nodes
  are positioned by setX. Coordinates are stored in each node's ``data.x``.

  Returns the result of ``nodeHeights(t)``, which setX indexes by node id.
  """
  root = t.node(t.root)
  setTerms(t, root, 0, dtips)

  first_x = root.data.c[0][0]
  last_x = root.data.c[-1][0]
  root.data.x = (first_x + last_x)/2

  heights = nodeHeights(t)
  for cid in root.succ:
    setX(t, t.node(cid), root, heights)

  return heights


def geta(n,name) :
  """Return attribute ``name`` of node ``n``, or None when it is not set.

  Attributes are looked up in ``n.data.attributes``; a node whose data has
  no ``attributes`` member at all simply yields None.
  """
  data = n.data
  if not hasattr(data, "attributes") :
    return None
  return data.attributes.get(name)

def output(tree, dtips = 1) :
  """Print the tikz commands that draw ``tree`` to standard output.

  Emitted in this order:
    1. one ``tip`` node per terminal at height 0, left to right, labelled
       with the taxon name;
    2. one empty node per internal node at (x, node height);
    3. a circle on every internal node, styled by the node's ``style``
       attribute or ``internal`` when it has none;
    4. a ``line`` edge from every non-root node to its parent;
    5. for every internal node carrying a ``vline`` attribute, a line at
       that node's height spanning from the first to the last tip, drawn
       with the attribute's value as tikz style.

  ``dtips`` is the distance between consecutive tips.
  """
  heights = treePoints(tree, dtips)

  leaves = [tree.node(tid) for tid in tree.get_terminals()]
  tips = sorted([[leaf.data.x, leaf] for leaf in leaves])
  # Horizontal extent of the drawing, used by the 'vline' lines below.
  left_x, right_x = tips[0][0], tips[-1][0]

  for x, leaf in tips :
    print("\\node[tip] (c%d) at (%s,%s) {%s};" % (leaf.id, x, 0, leaf.data.taxon))

  nodes = [tree.node(nid) for nid in tree.all_ids()]
  internals = [nd for nd in nodes if nd.succ]

  for nd in internals :
    print("\\node (c%d) at (%s,%s) {};" % (nd.id, nd.data.x, heights[nd.id]))

  marker = "circle"
  radius = "3pt"
  fallback = "internal"
  for nd in internals :
    print("\\draw[%s] (c%d) %s (%s);" % (geta(nd,"style") or fallback, nd.id, marker, radius))

  for nd in nodes :
    if nd.prev is not None :
      print("\\draw[line] (c%d) -- (c%d);" % (nd.id, nd.prev))

  for nd in internals :
    hl = geta(nd,'vline')
    if hl :
      y = heights[nd.id]
      print("\\draw[%s] (%s,%s) -- (%s,%s);" % (hl, left_x, y, right_x, y))

# Default tikz styles, keyed by style name. Any of them can be replaced from
# the command line with --style.
styles = {
  'tip' : "text centered",
  'internal' : "green",
  'line' : "draw,blue,thick",
  'hline'  : "densely dotted, thick",
}

# User supplied picture options are appended after 'transform shape',
# e.g. "scale=0.6,node distance = .5cm,auto".
gopts = (',' + options.gopts) if options.gopts else ""

if options.standalone :
  # Minimal LaTeX preamble so the output compiles on its own.
  print("""\\documentclass{article}

\\usepackage{tikz}
\\usetikzlibrary{shapes,positioning}

\\begin{document}
""")

print("""\\begin{tikzpicture}[transform shape%s]""" % gopts)
print("")

# User styles come first; each one also suppresses the default of the same
# name so a style is never defined twice.
if options.style:
  for s in options.style:
    # Split on the FIRST '=' only: tikz options are themselves key=value
    # pairs (e.g. "line=draw,line width=2pt"), so the value may contain '='.
    name, sep, attrs = s.partition('=')
    if not sep :
      sys.stderr.write("*** Error: expecting a style in 'name=tikz options' format: [ %s ]\n" % s)
      sys.exit(1)
    print("""\\tikzstyle{%s} = [%s]""" % (name,attrs))
    styles.pop(name, None)

# Remaining (not overridden) default styles.
for name,attrs in styles.items() :
  print("""\\tikzstyle{%s} = [%s]""" % (name,attrs))

print("")

output(tree, dtips = options.dtips)

print("""\\end{tikzpicture}""")

if options.standalone :
  print("""
\\end{document}
""")
